import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from utils import undersample
from skimage import metrics
from sense import sense

# Load the 8-channel brain dataset. Both arrays are stored coil-last and are
# transposed to coil-first, giving (8, 160, 220) = (coil, x, y).
# NOTE(review): `map` shadows the Python builtin; the name is kept because
# code outside this section may refer to it.
data = sio.loadmat('brain_8ch.mat')
im, map = (np.transpose(data[key], (2, 0, 1)) for key in ('im', 'map'))

# Quick-look sigpy plots of the coil images and the coil sensitivities.
im_p, map_p = (pl.ImagePlot(vol, z=0, hide_axes=True) for vol in (im, map))

# Retrospectively undersampled coil images.
# NOTE(review): the meaning of undersample's (1, 2) arguments lives in utils;
# grappa_kernel assumes 2x undersampling along axis 1 (x) -- confirm they agree.
imu = undersample(im, 1, 2)
imu_p = pl.ImagePlot(imu, z=0, hide_axes=True)

# 2-D FFT over the two spatial axes of every coil image:
# ksp is the fully sampled k-space, kspu the undersampled one.
ksp, kspu = (sp.fft(vol, axes=(1, 2)) for vol in (im, imu))

"""
Solve inverse problem to get GRAPPA kernel
Assumes undersampling in x and a (2x3) kernel
"""
def grappa_kernel(ksp, cal_width):
    nc, nkx, nky = ksp.shape
    kspcal = np.transpose(ksp[:, (nkx - cal_width)//2:(nkx + cal_width)//2, :], (1, 2, 0))   # kspace calibration data (ncal, 220, 8)
    xpos = np.array([-1,1])
    xpad = len(xpos) // 2
    ypos = np.array([-2,-1,0,1,2])
    ypad = len(ypos) // 2
    kspcal_padded = np.pad(kspcal, ((xpad, xpad), (ypad, ypad), (0, 0)), 'constant', constant_values=0)

    # Form the forward model and measurement matrices
    A = np.zeros([cal_width, nky, len(xpos), len(ypos), nc], dtype=kspcal.dtype)
    B = np.zeros([cal_width, nky, nc], dtype=kspcal.dtype)
    for nx in range(cal_width):
        for ny in range(nky):
            A[nx, ny, ...] = kspcal_padded[(nx + xpos + xpad)[:, None], (ny + ypos + ypad), ...]
            B[nx, ny, ...] = kspcal_padded[nx + xpad, ny + ypad]

    # Estimate kernel weights in least squares fashion
    A = np.reshape(A, [cal_width * nky, len(xpos) * len(ypos) * nc])
    B = np.reshape(B, [cal_width * nky, nc])
    X = np.linalg.pinv(A) @ B
    grappa_kernel = np.reshape(X, (len(xpos), len(ypos), nc, nc))
    grappa_kernel = np.insert(grappa_kernel, xpad, 0, axis=0)
    grappa_kernel = np.transpose(grappa_kernel, [3, 2, 0, 1])   # GRAPPA kernel (nc, nc, xpos, ypos)
    return grappa_kernel

# Test grappa_kernel on brain_mat Rx = 2, Ry = 1.
# The fitted weights get their own name so the grappa_kernel function is not
# shadowed by its result (previously the function became uncallable here).
# NOTE(review): calibration uses the fully sampled `ksp`, not an ACS region
# of `kspu` -- confirm this is intended.
# convolve2d flips its kernel, so pre-flip to apply the weights as a correlation.
kernel = np.flip(grappa_kernel(ksp, 20), (-2, -1))
n_coils = kspu.shape[0]
ksp_grappa = np.copy(kspu)
for coil_i in range(n_coils):          # target coil
    for coil_j in range(n_coils):      # source coil
        ksp_grappa[coil_i, ...] += signal.convolve2d(
            kspu[coil_j, ...], kernel[coil_i, coil_j, ...], mode='same')
im_grappa = sp.ifft(ksp_grappa, axes=(1, 2))    # GRAPPA-reconstructed parallel image

# Combine coil images into one magnitude image using root-sum-of-squares,
# sqrt(sum_c |x_c|^2), over the coil axis (axis 0). The previous loop squared
# the complex values without conjugation (x * x instead of |x|^2), which is
# not a sum of squares, and hard-coded 8 coils.
im_grappa_ss = np.sqrt(np.sum(np.abs(im_grappa) ** 2, axis=0))
im_ss = np.sqrt(np.sum(np.abs(im) ** 2, axis=0))

# Draw each image on a fresh figure so it neither lands on whatever axes are
# current (e.g. an earlier ImagePlot window) nor stacks on the previous image.
for image, filename in ((im_grappa_ss, "grappa_recon.png"),
                        (im_ss, "ground_truth.png")):
    plt.figure()
    plt.imshow(np.abs(image), cmap='gray')
    plt.savefig(filename)
    plt.close()

# Calculate pSNR of the R=2 GRAPPA reconstruction against the fully sampled
# root-sum-of-squares image. The peak (data_range) must come from the
# reference image, not from the reconstruction under test.
reference = np.abs(im_ss)
recon = np.abs(im_grappa_ss)
data_range = reference.max() - reference.min()
pSNR = metrics.peak_signal_noise_ratio(reference, recon, data_range=data_range)
print(f"pSNR for GRAPPA is {pSNR} dB")